activation-level distillation #388
base: main
Conversation
great progress! did you freeze everything except the randomly initialized mixers?
Resetting and distilling only one layer while freezing the rest of the model gives satisfactory results.
Note: some changes were required to allow loading a pretrained model while freezing certain layers (#394).
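For illustration, a minimal PyTorch-style sketch of that setup, freezing every parameter except one block's mixer (the `layers.{i}.mixer` module path and the helper itself are assumptions for the example, not the actual Fast-LLM freezing config from #394):

```python
import torch


def freeze_all_but_mixer(model: torch.nn.Module, layer_index: int) -> None:
    """Freeze every parameter except the (re-initialized) mixer of one block,
    so only that mixer is trained against the teacher's activations."""
    target_prefix = f"layers.{layer_index}.mixer."  # hypothetical module path
    for name, param in model.named_parameters():
        param.requires_grad_(name.startswith(target_prefix))


# e.g. distill only layer 12's mixer, everything else stays frozen:
# freeze_all_but_mixer(student_model, layer_index=12)
```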
ModelTestingGroup.convert: ModelTestingGroupAction.unimportant,
ModelTestingGroup.generate: ModelTestingGroupAction.unimportant,
ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented,
ModelTestingGroup.distributed: ModelTestingGroupAction.broken,  # failing: tp2, stp2, stp2_ce4
We'll probably want to leave these as unimportant and run them once in a while, because the testing suite can't really support many distributed runs.
Makes sense. Some of the distributed tests are failing currently, so let's leave it as broken for now?
| """ | ||
| Maybe apply activation distillation loss and setup backward hooks | ||
| """ | ||
| mixer_output = hidden_states if bias is None else hidden_states + bias |
This should only be evaluated if needed.
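Something along these lines, for example (a sketch only; the helper name and the explicit `distillation_enabled` flag are illustrative, not the PR's actual control flow):

```python
from typing import Optional

import torch


def maybe_mixer_output(
    hidden_states: torch.Tensor,
    bias: Optional[torch.Tensor],
    distillation_enabled: bool,
) -> Optional[torch.Tensor]:
    # Only materialize the (bias-added) mixer output when the activation
    # distillation loss will actually consume it; otherwise skip the add.
    if not distillation_enabled:
        return None
    return hidden_states if bias is None else hidden_states + bias
```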
self.mlp.preprocess(kwargs)

# TODO: add layer_index
_activation_distillation_loss_name = "activation_distillation_loss"
would be nice to have a layer index in logging
Agree! This involves a few more changes, because FixedBlockSequence currently assumes that all the blocks have the same loss definitions.
We could add this in a second PR.
We'd also need to be careful not to drown the logs in these activation losses for each layer
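For reference, a minimal sketch of what per-layer loss naming could look like in a follow-up (the helper and naming scheme are hypothetical, not part of this PR):

```python
from typing import Optional


def activation_distillation_loss_name(layer_index: Optional[int] = None) -> str:
    """Return the loss name, optionally suffixed with the block's layer index."""
    base = "activation_distillation_loss"
    return base if layer_index is None else f"{base}_{layer_index}"
```

How often each per-layer entry actually gets logged could then be gated separately, which would also address the log-volume concern.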
Thank you for the reviews! The comments are addressed; could you have another look? @jlamypoirier



✨ Description
Closes #385
TODOs:
- … 0 and gradients as well.

Sanity checks:

- … 0 loss ✔️. But the loss then increases to a small value instead of staying at 0 (see the sketch below).
- … 0 loss (orange)

With the caveat that distillation seems to experience memory spikes at specific points in training. The actual usage was lower most of the time.
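For context, a self-contained sketch of the zero-loss sanity check mentioned above, assuming an MSE-style activation loss (the helper name and the choice of MSE are assumptions, not necessarily what this PR implements):

```python
import torch
import torch.nn.functional as F


def activation_distillation_loss(student_act: torch.Tensor, teacher_act: torch.Tensor) -> torch.Tensor:
    # Assumed MSE between the student's and the (detached) teacher's mixer outputs.
    return F.mse_loss(student_act, teacher_act.detach())


# Identical student/teacher activations should give exactly 0 loss and zero gradients.
student = torch.randn(2, 8, 16, requires_grad=True)
loss = activation_distillation_loss(student, student.detach().clone())
loss.backward()
assert loss.item() == 0.0
assert torch.all(student.grad == 0)
```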
🔍 Type of change
Select all that apply:
Testing
Performance Impact
📊 Performance Impact Details
If there is any impact on performance, describe it and provide benchmark results, if applicable: